import numpy as np
import pandas as pd
from tqdm import tqdm
import random
import os

from jepa_spec_trainer import Jepa_Spec
from jepa_trainer import Jepa

from data_utils import *

import torch
from torch.optim import AdamW, Adam
from torch.optim.lr_scheduler import LinearLR
from torch.utils.data import DataLoader

os.environ['PATH']

device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

######################################################################################################################################################################################
########################################################################## REPRODUCIBILITY ###########################################################################################
######################################################################################################################################################################################

torch.manual_seed(0)
random.seed(0)
np.random.seed(0)


######################################################################################################################################################################################
########################################################################## LOADING THE DATA ##########################################################################################
######################################################################################################################################################################################

csv = pd.read_csv("training_data_2.csv")
N_ITER = 40001
######################################################################################################################################################################################
##################################################################### MODEL INITIALIZATION ###########################################################################################
######################################################################################################################################################################################

ENC_NUM_BLOCKS = 4
DEC_NUM_BLOCKS = 2

ENC_DIM = 128
DEC_DIM = 128

NUM_HEADS = 4

PATCH_SIZE = 10
FS = 100
L = 10


NUM_STRIPS = 4
BATCH_SIZE = 250

MASKING_RATIO = 0.5


for use_spec in [1, 0]:

    for MASKING_RATIO in [0.4, 0.5, 0.6]:
        
        if use_spec == 1:
            
            weights_name = "jepa_spec_" + str(MASKING_RATIO)

            # Loading The Model
            trainer = Jepa_Spec(enc_num_blocks=ENC_NUM_BLOCKS, num_heads=NUM_HEADS, model_dim=ENC_DIM, do_prob=0.1, patch_size=PATCH_SIZE, 
                                in_channels=1, fs=FS, l=L, dec_num_blocks=DEC_NUM_BLOCKS, dec_dim = DEC_DIM,
                                mask_ratio=MASKING_RATIO).to(device)


            optimizer = AdamW(trainer.parameters(), 
                        lr=1e-3, 
                        betas=(0.9, 0.95))


            
        else:

            
            
            weights_name = "jepa_" + str(MASKING_RATIO)

            # Loading The Model
            trainer = Jepa(enc_num_blocks=ENC_NUM_BLOCKS, num_heads=NUM_HEADS, model_dim=ENC_DIM, do_prob=0.1, patch_size=PATCH_SIZE, 
                                in_channels=1, fs=FS, l=L, dec_num_blocks=DEC_NUM_BLOCKS, dec_dim = DEC_DIM,
                                mask_ratio=MASKING_RATIO).to(device)


            optimizer = AdamW(trainer.parameters(), 
                        lr=1e-3, 
                        betas=(0.9, 0.95))

        print("\nTraining: " + str(weights_name))

        ds = SHHS_DataLoader(csv, num_strips = NUM_STRIPS, fs=FS, l = L)
        dl = DataLoader(ds, batch_size = BATCH_SIZE, shuffle=True, drop_last=True)

        EPOCHS = int(np.ceil(N_ITER / len(dl)))

        scheduler_1 = LinearLR(optimizer, total_iters=10, verbose=False)
        p_bar = tqdm(range(N_ITER))

        counter = 0
        rec_losses = []
        best_accu = 0

        for epoch in range(EPOCHS):
            for batch in iter(dl):

                trainer.train()
                optimizer.zero_grad()
                
                x1  = batch
                loss = trainer(x1)
                loss.backward()
                optimizer.step()
                trainer.update_moving_average()
                
                counter += 1

                if counter == N_ITER:
                    torch.save(trainer.to("cpu").state_dict(), weights_name + ".pth")
                    trainer.to(device)
                    break

                # p_bar.set_description("Loss : %s, Dis Loss : %s, Gra Loss : %s, Static Loss : %s" % (loss.item(), dissim_loss, gradual_loss, static_loss))
                p_bar.set_description("Loss : %s, Rec Loss: %s, Gra Loss: %s" % (loss.item()))
                p_bar.update(1)
                p_bar.refresh()
                
  
                if counter % 200 == 0:
                    scheduler_1.step()

                